# test_security.py - Security-focused tests for the Evaluations module
"""
Security test suite for the Evaluations module.

Tests for:
- Path traversal protection
- SQL injection prevention
- SSRF protection in webhooks
- Input validation
- Score parsing security
- Thread safety
"""

import pytest
import tempfile
import sqlite3
import os
import sys
import asyncio
from pathlib import Path
from unittest.mock import patch, MagicMock, Mock
from concurrent.futures import ThreadPoolExecutor
import threading
import queue
import time

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from tldw_Server_API.app.core.Evaluations.evaluation_manager import EvaluationManager
from tldw_Server_API.app.core.DB_Management.db_path_utils import DatabasePaths
from tldw_Server_API.app.core.Evaluations.webhook_security import WebhookSecurityValidator
from tldw_Server_API.app.core.Evaluations.connection_pool import ConnectionPool


class TestPathTraversalProtection:
    """Test path traversal attack prevention"""

    def test_path_traversal_in_config(self):
        """Test that path traversal attempts in config are blocked"""
        manager = EvaluationManager()
        expected_path = DatabasePaths.get_evaluations_db_path(DatabasePaths.get_single_user_id()).resolve()

        # Mock config with path traversal attempt
        with patch.object(manager, 'config') as mock_config:
            mock_config.has_section.return_value = True
            mock_config.get.return_value = "../../../../../../etc/passwd"

            # Re-get the path with malicious config
            safe_path = manager._get_db_path()

            # Should not traverse outside project directory
            assert "../" not in str(safe_path)
            assert "etc/passwd" not in str(safe_path)
            assert safe_path == expected_path

    def test_absolute_path_outside_project(self):
        """Test that absolute paths outside project are rejected"""
        manager = EvaluationManager()
        expected_path = DatabasePaths.get_evaluations_db_path(DatabasePaths.get_single_user_id()).resolve()

        with patch.object(manager, 'config') as mock_config:
            mock_config.has_section.return_value = True
            mock_config.get.return_value = "/etc/sensitive/database.db"

            safe_path = manager._get_db_path()

            # Should use safe default path instead
            assert "/etc/sensitive" not in str(safe_path)
            assert safe_path == expected_path

    def test_null_byte_injection(self):
        """Test that null byte injection is handled"""
        manager = EvaluationManager()

        with patch.object(manager, 'config') as mock_config:
            mock_config.has_section.return_value = True
            mock_config.get.return_value = "database.db\x00.txt"

            safe_path = manager._get_db_path()

            # Null bytes should be handled safely
            assert "\x00" not in str(safe_path)


class TestSQLInjectionPrevention:
    """Test SQL injection attack prevention"""

    @pytest.mark.asyncio
    async def test_sql_injection_in_evaluation_id(self):
        """Test that SQL injection in evaluation IDs is prevented"""
        manager = EvaluationManager()

        # Try SQL injection in evaluation storage
        malicious_input = {
            "test": "'; DROP TABLE internal_evaluations; --"
        }

        # This should not execute the DROP TABLE command
        eval_id = await manager.store_evaluation(
            evaluation_type="test",
            input_data=malicious_input,
            results={"score": 0.5}
        )

        # Verify table still exists
        with sqlite3.connect(manager.db_path) as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='internal_evaluations'")
            assert cursor.fetchone() is not None

    @pytest.mark.asyncio
    async def test_sql_injection_in_history_query(self):
        """Test SQL injection prevention in history queries"""
        manager = EvaluationManager()

        # Try SQL injection in evaluation type filter
        malicious_type = "test' OR 1=1 --"

        # This should safely handle the malicious input
        history = await manager.get_history(
            evaluation_type=malicious_type,
            limit=10
        )

        # Should return empty or filtered results, not all results
        assert history is not None
        # The query should be properly parameterized


class TestWebhookSecurityValidation:
    """Test webhook security validation"""

    @pytest.mark.asyncio
    async def test_ssrf_protection_private_ip(self):
        """Test that private IPs are blocked"""
        validator = WebhookSecurityValidator()

        # Test various private IP addresses
        private_urls = [
            "http://127.0.0.1/webhook",
            "http://localhost/webhook",
            "http://192.168.1.1/webhook",
            "http://10.0.0.1/webhook",
            "http://172.16.0.1/webhook",
            "http://[::1]/webhook",
            "http://169.254.169.254/metadata"  # AWS metadata endpoint
        ]

        for url in private_urls:
            result = await validator.validate_webhook_url(
                url=url,
                user_id="test_user",
                check_connectivity=False
            )

            # Should have errors for private IPs
            assert not result.valid
            assert any(
                e.code in ["PRIVATE_NETWORK", "PRIVATE_IP", "PRIVATE_NETWORK_WARNING"]
                for e in result.errors + result.warnings
            )

    @pytest.mark.asyncio
    async def test_ssrf_dns_rebinding(self):
        """Test protection against DNS rebinding attacks"""
        validator = WebhookSecurityValidator()

        # Mock DNS resolution that changes between checks
        with patch('socket.getaddrinfo') as mock_getaddrinfo:
            # First resolution returns public IP, second returns private
            mock_getaddrinfo.side_effect = [
                [(None, None, None, None, ('8.8.8.8',))],  # Initial check
                [(None, None, None, None, ('192.168.1.1',))]  # Redirect check
            ]

            result = await validator.validate_webhook_url(
                url="http://attacker.com/webhook",
                user_id="test_user",
                check_connectivity=False
            )

            # Should detect the DNS rebinding attempt
            # Note: This requires implementation of redirect validation

    @pytest.mark.asyncio
    async def test_blocked_ports(self):
        """Test that sensitive ports are blocked"""
        validator = WebhookSecurityValidator()

        blocked_port_urls = [
            "http://example.com:22/webhook",   # SSH
            "http://example.com:3306/webhook",  # MySQL
            "http://example.com:5432/webhook",  # PostgreSQL
            "http://example.com:6379/webhook",  # Redis
        ]

        for url in blocked_port_urls:
            result = await validator.validate_webhook_url(
                url=url,
                user_id="test_user",
                check_connectivity=False
            )

            assert not result.valid
            assert any(e.code == "BLOCKED_PORT" for e in result.errors)


class TestScoreParsingSecurity:
    """Test secure score parsing"""

    @pytest.mark.asyncio
    async def test_malicious_score_injection(self):
        """Test that malicious score inputs are handled safely"""
        manager = EvaluationManager()

        # Mock the analyze function to return malicious responses
        malicious_responses = [
            "999999999999999999999999",  # Huge number
            "-999999",  # Negative number
            "Score: <script>alert('xss')</script>",  # XSS attempt
            '{"score": "\'; DROP TABLE evaluations; --"}',  # SQL injection in JSON
            "Score: NaN",  # Not a number
            "Score: Infinity",  # Infinity
            "Score: 1e308",  # Very large number
        ]

        for response in malicious_responses:
            with patch('tldw_Server_API.app.core.Evaluations.evaluation_manager.analyze') as mock_analyze:
                mock_analyze.return_value = response

                result = await manager.evaluate_custom_metric(
                    metric_name="test",
                    description="test",
                    evaluation_prompt="test",
                    input_data={"test": "data"},
                    scoring_criteria={"test": "criteria"}
                )

                # Score should be safely bounded between 0 and 1
                assert 0 <= result["score"] <= 1
                assert isinstance(result["score"], float)

    @pytest.mark.asyncio
    async def test_json_parsing_security(self):
        """Test that JSON parsing is secure"""
        manager = EvaluationManager()

        # Test various JSON edge cases
        json_responses = [
            '{"score": 8, "__proto__": {"isAdmin": true}}',  # Prototype pollution attempt
            '{"score": 7, "constructor": {"prototype": {"isAdmin": true}}}',  # Constructor attack
            '{"score": [1,2,3,4,5,6,7,8,9,10]}',  # Array instead of number
            '{"score": {"value": 8}}',  # Nested object
        ]

        for response in json_responses:
            with patch('tldw_Server_API.app.core.Evaluations.evaluation_manager.analyze') as mock_analyze:
                mock_analyze.return_value = response

                result = await manager.evaluate_custom_metric(
                    metric_name="test",
                    description="test",
                    evaluation_prompt="test",
                    input_data={"test": "data"},
                    scoring_criteria={"test": "criteria"}
                )

                # Should handle gracefully
                assert 0 <= result["score"] <= 1


class TestConnectionPoolThreadSafety:
    """Test connection pool thread safety"""

    def test_concurrent_connection_access(self):
        """Test that connections are thread-safe"""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp:
            pool = ConnectionPool(
                db_path=tmp.name,
                pool_size=5,
                max_overflow=10
            )

            errors = []
            results = []

            def worker(worker_id):
                """Worker function that uses connections"""
                try:
                    for i in range(10):
                        with pool.get_connection() as conn:
                            # Each thread writes to its own table
                            table_name = f"test_table_{worker_id}"
                            conn.execute(f"CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER, value TEXT)")
                            conn.execute(f"INSERT INTO {table_name} VALUES (?, ?)", (i, f"value_{i}"))
                            conn.commit()

                            # Verify write
                            cursor = conn.execute(f"SELECT COUNT(*) FROM {table_name}")
                            count = cursor.fetchone()[0]
                            results.append((worker_id, count))

                except Exception as e:
                    errors.append((worker_id, str(e)))

            # Run multiple threads concurrently
            threads = []
            for i in range(10):
                t = threading.Thread(target=worker, args=(i,))
                threads.append(t)
                t.start()

            # Wait for all threads
            for t in threads:
                t.join()

            # Clean up
            pool.shutdown()
            os.unlink(tmp.name)

            # Check results
            assert len(errors) == 0, f"Thread safety errors: {errors}"

            # Each worker should have successfully written its data
            for worker_id in range(10):
                worker_results = [r for r in results if r[0] == worker_id]
                assert len(worker_results) == 10
                # Counts should increase from 1 to 10
                counts = [r[1] for r in worker_results]
                assert counts == list(range(1, 11))

    def test_connection_pool_exhaustion(self):
        """Test that pool exhaustion is handled correctly"""
        import time
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp:
            # Create a small pool
            pool = ConnectionPool(
                db_path=tmp.name,
                pool_size=2,
                max_overflow=1,
                pool_timeout=1.0  # Short timeout for testing
            )

            connections = []
            contexts = []
            try:
                # Get all available connections using context managers properly
                for i in range(3):
                    ctx = pool.get_connection()
                    conn = ctx.__enter__()
                    connections.append(conn)
                    contexts.append(ctx)

                # Try to get one more - should timeout or raise appropriate exception
                # The pool should either raise TimeoutError or queue.Empty
                start = time.time()
                try:
                    with pool.get_connection() as conn:
                        # If we get here, the pool didn't enforce its limits properly
                        assert False, "Pool should have been exhausted"
                except (TimeoutError, queue.Empty, RuntimeError) as e:
                    # Expected - pool is exhausted
                    elapsed = time.time() - start
                    # Should timeout quickly (within 2 seconds)
                    assert elapsed <= 2.0, f"Timeout took too long: {elapsed:.2f}s"

            finally:
                # Properly exit all context managers
                for ctx in reversed(contexts):
                    try:
                        ctx.__exit__(None, None, None)
                    except:
                        pass

                pool.shutdown()
                os.unlink(tmp.name)


class TestInputValidation:
    """Test input validation across the module"""

    @pytest.mark.asyncio
    async def test_evaluation_type_validation(self):
        """Test that evaluation types are validated"""
        manager = EvaluationManager()

        # Test with various invalid types
        invalid_types = [
            "",  # Empty string
            None,  # None
            "x" * 1000,  # Very long string
            "../../etc/passwd",  # Path traversal attempt
            "'; DROP TABLE --",  # SQL injection
            "\x00null\x00",  # Null bytes
        ]

        for invalid_type in invalid_types:
            # Should handle gracefully without crashes
            try:
                result = await manager.store_evaluation(
                    evaluation_type=invalid_type or "test",
                    input_data={"test": "data"},
                    results={"score": 0.5}
                )
                # If it succeeds, the value should be sanitized
                assert result is not None
            except (ValueError, TypeError):
                # Expected for None and other invalid types
                pass

    @pytest.mark.asyncio
    async def test_metadata_size_limits(self):
        """Test that large metadata is handled properly"""
        manager = EvaluationManager()

        # Create very large metadata
        large_metadata = {
            "huge_field": "x" * (10 * 1024 * 1024),  # 10MB string
            "nested": {str(i): "value" for i in range(10000)}  # Many keys
        }

        # Should handle large data gracefully
        # This might truncate or reject, but shouldn't crash
        eval_id = await manager.store_evaluation(
            evaluation_type="test",
            input_data={"test": "data"},
            results={"score": 0.5},
            metadata=large_metadata
        )

        # Verify it was stored (possibly truncated)
        assert eval_id is not None


@pytest.mark.asyncio
async def test_rate_limiting():
    """Test that rate limiting works correctly"""
    # This would require the API endpoints
    # Leaving as placeholder for API-level testing
    pass


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
